# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections import OrderedDict

import numpy as np
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms import ToPILImage
from tqdm import tqdm

import nncf


def get_files(folder, name_filter=None, extension_filter=None):
    """Helper function that returns the list of files in a specified folder
    with a specified extension.

    Keyword arguments:
    - folder (``string``): The path to a folder.
    - name_filter (``string``, optional): The returned files must contain
    this substring in their filename. Default: None; files are not filtered.
    - extension_filter (``string``, optional): The desired file extension.
    Default: None; files are not filtered.

    """
    if not os.path.isdir(folder):
        raise nncf.InvalidPathError('"{0}" is not a folder.'.format(folder))

    # Filename filter: if not specified don't filter (condition always true);
    # otherwise, use a lambda expression to filter out files that do not
    # contain "name_filter"
    if name_filter is None:
        name_cond = lambda filename: True
    else:
        name_cond = lambda filename: name_filter in filename

    # Extension filter: if not specified don't filter (condition always true);
    # otherwise, use a lambda expression to filter out files whose extension
    # is not "extension_filter"
    if extension_filter is None:
        ext_cond = lambda filename: True
    else:
        ext_cond = lambda filename: filename.endswith(extension_filter)

    filtered_files = []

    # Explore the directory tree to get files that contain "name_filter" and
    # with extension "extension_filter"
    for path, _, files in os.walk(folder):
        files.sort()
        for file in files:
            if name_cond(file) and ext_cond(file):
                full_path = os.path.join(path, file)
                filtered_files.append(full_path)

    return filtered_files
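

# Usage sketch (hypothetical dataset layout; adjust the path and filters to
# your data):
#
#     train_images = get_files("data/camvid", name_filter="train", extension_filter=".png")
#     # -> list of full paths, e.g. "data/camvid/train/0001TP_006690.png"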


def pil_loader(data_path, label_path):
    """Loads a sample and label image given their path as PIL images.

    Keyword arguments:
    - data_path (``string``): The filepath to the image.
    - label_path (``string``): The filepath to the ground-truth image.

    Returns the image and the label as PIL images.

    """
    data = Image.open(data_path)
    label = Image.open(label_path)

    return data, label


def remap(image, old_values, new_values):
    """Replaces pixel values in ``image``: each value in ``old_values`` is
    mapped to the value at the same position in ``new_values``; values not
    listed in ``old_values`` are set to 0.

    Keyword arguments:
    - image (``PIL.Image`` or ``numpy.ndarray``): The image to remap.
    - old_values (``tuple``): The pixel values to replace.
    - new_values (``tuple``): The replacement pixel values.

    Returns the remapped image as a ``PIL.Image``.

    """
    assert isinstance(image, (Image.Image, np.ndarray)), "image must be of type PIL.Image or numpy.ndarray"
    assert isinstance(new_values, tuple), "new_values must be of type tuple"
    assert isinstance(old_values, tuple), "old_values must be of type tuple"
    assert len(new_values) == len(old_values), "new_values and old_values must have the same length"

    # If image is a PIL.Image convert it to a numpy array
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Replace old values by the new ones
    tmp = np.zeros_like(image)
    for old, new in zip(old_values, new_values):
        # Since tmp is already initialized as zeros we can skip new values
        # equal to 0
        if new != 0:
            tmp[image == old] = new

    return Image.fromarray(tmp)
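

# Usage sketch: collapse two hypothetical "void" ids (254 and 255) into
# class 0 while keeping class 1. Values absent from old_values also end up
# as 0, because the output starts out all-zero:
#
#     arr = np.array([[1, 254], [255, 1]], dtype=np.uint8)
#     out = remap(arr, old_values=(1, 254, 255), new_values=(1, 0, 0))
#     # np.array(out) -> [[1, 0], [0, 1]]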


def enet_weighing(dataloader, num_classes, c=1.02):
    """Computes class weights as described in the ENet paper:

        w_class = 1 / (ln(c + p_class)),

    where c is usually 1.02 and p_class is the propensity score of that
    class:

        propensity_score = freq_class / total_pixels.

    References: https://arxiv.org/abs/1606.02147

    Keyword arguments:
    - dataloader (``data.Dataloader``): A data loader to iterate over the
    dataset.
    - num_classes (``int``): The number of classes.
    - c (``float``, optional): An additional hyper-parameter which restricts
    the interval of values for the weights. Default: 1.02.

    """
    class_count = np.zeros(num_classes)
    total = np.zeros(num_classes)
    for _, label in tqdm(dataloader):
        label = label.cpu().numpy()

        # Flatten label
        flat_label = label.flatten()

        # Ignore out-of-bounds target labels
        valid_indices = np.where((flat_label >= 0) & (flat_label < num_classes))
        flat_label = flat_label[valid_indices]

        # Sum up the number of pixels of each class and the total pixel
        # counts for each label
        class_count += np.bincount(flat_label, minlength=num_classes)
        total += flat_label.size

    # Compute propensity score and then the weights for each class
    propensity_score = class_count / total
    class_weights = 1 / (np.log(c + propensity_score))

    return class_weights
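

# Worked example of the weighting: with c = 1.02, a dominant class covering
# 90% of all pixels gets w = 1 / ln(1.02 + 0.90) ≈ 1.53, while a rare class
# covering 1% of pixels gets w = 1 / ln(1.02 + 0.01) ≈ 33.8, so the rare
# class is weighted about 22x higher.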


def median_freq_balancing(dataloader, num_classes):
    """Computes class weights using median frequency balancing as described
    in https://arxiv.org/abs/1411.4734:

        w_class = median_freq / freq_class,

    where freq_class is the number of pixels of a given class divided by
    the total number of pixels in images where that class is present, and
    median_freq is the median of freq_class.

    Keyword arguments:
    - dataloader (``data.Dataloader``): A data loader to iterate over the
    dataset whose weights are going to be computed.
    - num_classes (``int``): The number of classes.

    """
    class_count = np.zeros(num_classes)
    total = np.zeros(num_classes)
    for _, label in tqdm(dataloader):
        label = label.cpu().numpy()

        # Flatten label
        flat_label = label.flatten()

        # Ignore out-of-bounds target labels
        valid_indices = np.where((flat_label >= 0) & (flat_label < num_classes))
        flat_label = flat_label[valid_indices]

        # Sum up the class frequencies
        bincount = np.bincount(flat_label, minlength=num_classes)

        # Create a mask of the classes that exist in the label
        mask = bincount > 0
        # Multiply the mask by the pixel count. The resulting array has
        # one element for each class. The value is either 0 (if the class
        # does not exist in the label) or equal to the pixel count (if
        # the class exists in the label)
        total += mask * flat_label.size

        # Sum up the number of pixels found for each class
        class_count += bincount

    # Compute the frequency and its median
    freq = class_count / total
    freq = np.nan_to_num(freq)  # Guard against 0/0 divisions
    med = np.median(freq)
    class_weights = np.nan_to_num(med / freq)

    return class_weights
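

# Worked example: for three classes with frequencies 0.50, 0.25, and 0.05,
# the median frequency is 0.25, so the weights are 0.25 / 0.50 = 0.5,
# 0.25 / 0.25 = 1.0, and 0.25 / 0.05 = 5.0; classes rarer than the median
# are upweighted.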


class LongTensorToRGBPIL:
    """Converts a ``torch.LongTensor`` to a ``PIL image``.

    The input is a ``torch.LongTensor`` where each pixel's value identifies the
    class.

    Keyword arguments:
    - rgb_encoding (``OrderedDict``): An ``OrderedDict`` that relates pixel
    values, class names, and class colors.

    """

    def __init__(self, rgb_encoding):
        self.rgb_encoding = rgb_encoding

    def __call__(self, tensor):
        """Performs the conversion from ``torch.LongTensor`` to a ``PIL image``

        Keyword arguments:
        - tensor (``torch.LongTensor``): the tensor to convert

        Returns:
        A ``PIL.Image``.

        """
        # Check if label_tensor is a LongTensor
        if not isinstance(tensor, torch.LongTensor):
            raise TypeError("label_tensor should be torch.LongTensor. Got {}".format(type(tensor)))
        # Check if encoding is an ordered dictionary
        if not isinstance(self.rgb_encoding, OrderedDict):
            raise TypeError("encoding should be an OrderedDict. Got {}".format(type(self.rgb_encoding)))

        # label_tensor might be an image without a channel dimension; in that
        # case, add one without mutating the caller's tensor
        if len(tensor.size()) == 2:
            tensor = tensor.unsqueeze(0)

        color_tensor = torch.ByteTensor(3, tensor.size(1), tensor.size(2)).fill_(0)

        for index, (_, color) in enumerate(self.rgb_encoding.items()):
            # Get a mask of elements equal to index
            mask = torch.eq(tensor, index).squeeze_()
            # Fill color_tensor with corresponding colors
            for channel, color_value in enumerate(color):
                color_tensor[channel].masked_fill_(mask, color_value)

        return ToPILImage()(color_tensor)
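

# Usage sketch (hypothetical two-class encoding; the position of each entry
# in the OrderedDict is the class index stored in the tensor):
#
#     encoding = OrderedDict([("background", (0, 0, 0)), ("road", (128, 64, 128))])
#     label = torch.zeros(4, 4, dtype=torch.long)
#     label[2:, :] = 1
#     rgb_image = LongTensorToRGBPIL(encoding)(label)  # a 4x4 RGB PIL image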


def batch_transform(batch, transform):
    """Applies a transform to a batch of samples.

    Keyword arguments:
    - batch (``Tensor``): a batch of samples
    - transform (callable): A function/transform to apply to ``batch``

    """

    # Convert the single channel label to RGB in tensor form
    # 1. torch.unbind removes the 0-dimension of "labels" and returns a tuple of
    # all slices along that dimension
    # 2. the transform is applied to each slice
    transf_slices = [transform(tensor) for tensor in torch.unbind(batch)]

    return torch.stack(transf_slices)


def imshow_batch(images, labels):
    """Displays two grids of images. The top grid displays ``images``
    and the bottom grid ``labels``

    Keyword arguments:
    - images (``Tensor``): a 4D mini-batch tensor of shape
    (B, C, H, W)
    - labels (``Tensor``): a 4D mini-batch tensor of shape
    (B, C, H, W)

    """

    # Make a grid with the images and labels and convert it to numpy
    images = torchvision.utils.make_grid(images).numpy()
    labels = torchvision.utils.make_grid(labels).numpy()

    import matplotlib.pyplot as plt

    _, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 7))
    ax1.imshow(np.transpose(images, (1, 2, 0)))
    ax2.imshow(np.transpose(labels, (1, 2, 0)))

    plt.show()


def label_to_color(label, class_encoding):
    """Converts a batch of label maps to a batch of RGB color tensors using
    the colors in ``class_encoding``.

    """
    label_to_rgb = T.Compose([LongTensorToRGBPIL(class_encoding), T.ToTensor()])
    return batch_transform(label.cpu(), label_to_rgb)
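

# Usage sketch (assumes an `encoding` OrderedDict like the hypothetical one
# in the LongTensorToRGBPIL example above):
#
#     batch = torch.zeros(2, 4, 4, dtype=torch.long)  # a batch of two label maps
#     colored = label_to_color(batch, encoding)       # float tensor of shape (2, 3, 4, 4)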


def color_to_label(color_labels: Image.Image, class_encoding: OrderedDict):
    """Converts an RGB color-encoded label image back to a single-channel
    image of class indices, matching pixel colors against ``class_encoding``.

    """
    color_labels = np.array(color_labels.convert("RGB"))

    labels = np.zeros((color_labels.shape[0], color_labels.shape[1]), dtype=np.int64)

    red = color_labels[..., 0]
    green = color_labels[..., 1]
    blue = color_labels[..., 2]

    for index, (_, color) in enumerate(class_encoding.items()):
        target_area = (red == color[0]) & (green == color[1]) & (blue == color[2])
        labels[target_area] = index

    labels = Image.fromarray(labels.astype(np.uint8), mode="L")
    return labels
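

# Usage sketch: color_to_label() inverts label_to_color() for a single image;
# pixels whose color matches no entry in the encoding fall back to class 0
# (continuing the hypothetical `encoding`/`colored` example above):
#
#     rgb_image = ToPILImage()(colored[0])
#     class_ids = color_to_label(rgb_image, encoding)  # "L"-mode PIL image of class indices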


def show_ground_truth_vs_prediction(images, gt_labels, color_predictions, class_encoding):
    """Displays three grids of images side by side: ``images``, the
    color-encoded ``gt_labels``, and ``color_predictions``, with a
    class-color legend underneath.

    Keyword arguments:
    - images (``Tensor``): a 4D mini-batch tensor of shape
    (B, C, H, W)
    - gt_labels (``Tensor``): a mini-batch of ground-truth label maps with
    one class index per pixel
    - color_predictions (``Tensor``): a 4D mini-batch tensor of shape
    (B, C, H, W)
    - class_encoding (``OrderedDict``): relates class names to class colors

    """

    import matplotlib.pyplot as plt

    # Make a grid with the images and labels and convert it to numpy
    images = torchvision.utils.make_grid(images).numpy()
    color_predictions = torchvision.utils.make_grid(color_predictions).numpy()

    color_gt_labels = label_to_color(gt_labels, class_encoding)
    color_gt_labels = torchvision.utils.make_grid(color_gt_labels).numpy()

    plt.figure(figsize=(15, 7))
    ax1 = plt.subplot2grid((3, 3), (0, 0), rowspan=2)
    ax2 = plt.subplot2grid((3, 3), (0, 1), rowspan=2)
    ax3 = plt.subplot2grid((3, 3), (0, 2), rowspan=2)
    ax4 = plt.subplot2grid((3, 3), (2, 0), colspan=3)

    ax1.imshow(np.transpose(images, (1, 2, 0)))
    ax2.imshow(np.transpose(color_gt_labels, (1, 2, 0)))
    ax3.imshow(np.transpose(color_predictions, (1, 2, 0)))

    color_names = [k for k, _ in class_encoding.items()]
    colors = [(v[0] / 255.0, v[1] / 255.0, v[2] / 255.0, 1.0) for _, v in class_encoding.items()]
    ones = np.ones(len(colors))
    ax4.bar(range(0, len(ones)), ones, color=colors, tick_label=color_names)

    plt.show()


def downsample_labels(labels, target_size=None, downsample_factor=None):
    """Downsamples a mini-batch of (B, H, W) label maps by sampling evenly
    spaced rows and columns, so class indices are never interpolated. Exactly
    one of ``target_size`` (an ``(h, w)`` pair) or ``downsample_factor``
    (an ``int``) must be specified.

    """
    H = labels.size()[1]
    W = labels.size()[2]
    if target_size is None and downsample_factor is None:
        raise ValueError("Either target_size or downsample_factor must be specified")
    if target_size is not None and downsample_factor is not None:
        raise ValueError("Only one of the target_size and downsample_factor must be specified")

    if downsample_factor is None:
        h = target_size[0]
        w = target_size[1]
    else:
        h = H // downsample_factor
        w = W // downsample_factor
    ih = torch.linspace(0, H - 1, h).long()
    iw = torch.linspace(0, W - 1, w).long()

    return labels[:, ih[:, None], iw]
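

# Usage sketch: class-index label maps must be downsampled by sampling, not
# interpolation (which would blend neighboring class ids into invalid values):
#
#     labels = torch.randint(0, 12, (4, 360, 480))
#     small = downsample_labels(labels, downsample_factor=4)  # shape (4, 90, 120)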


def cat_list(images, fill_value=0):
    """Stacks a list of tensors with possibly different spatial sizes into a
    single batch tensor, padding each of them with ``fill_value`` up to the
    largest size in the list.

    """
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    batch_shape = (len(images),) + max_size
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
    return batched_imgs
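

# Usage sketch: two tensors of different spatial sizes are padded up to a
# common (maximum) shape before stacking:
#
#     a = torch.ones(3, 100, 120)
#     b = torch.ones(3, 90, 150)
#     batched = cat_list([a, b], fill_value=0)  # shape (2, 3, 100, 150)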


def collate_fn(batch):
    """Collates a batch of (image, target) pairs of varying sizes into padded
    batch tensors; targets are padded with 255, which is commonly used as an
    "ignore" index by segmentation losses.

    """
    images, targets = list(zip(*batch))
    batched_imgs = cat_list(images, fill_value=0)
    batched_targets = cat_list(targets, fill_value=255)
    return batched_imgs, batched_targets
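

# Usage sketch (assumes `dataset` yields (image, target) tensor pairs): pass
# collate_fn to a DataLoader so variable-sized samples are padded into
# uniform batches:
#
#     loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collate_fn)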
